{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "26fdd9ed",
   "metadata": {},
   "source": [
    "# Federated TensorFlow MNIST Tutorial with a Dummy Envoy Shard Descriptor\n",
    "\n",
    "This tutorial is based on the \"Tensorflow_MNIST\" tutorial, modified to demonstrate how to use the dummy shard descriptor. By default, an OpenFL federation created by FedLCM uses only this dummy shard descriptor, so the user has to specify how each envoy retrieves its local data by defining a `DataInterface` subclass.\n",
    "\n",
    "**Understanding of the original \"Tensorflow_MNIST\" is highly recommended!**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0570122",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Install dependencies if not already installed\n",
    "!python -m pip install tensorflow==2.8"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6479191e-bb78-44cd-9462-8e420de6aa63",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06915d29-07bc-47e2-8d4a-ff73428ac104",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "print('TensorFlow', tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "246f9c98",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Connect to the Federation\n",
    "\n",
    "This cell connects this notebook to the Federation.\n",
    "\n",
    "*Note: the parameters provided in the cell are for the Jupyter Lab instance deployed along with the director by FedLCM. Change them if you are running this notebook in another environment.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d657e463",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create a federation\n",
    "from openfl.interface.interactive_api.federation import Federation\n",
    "\n",
    "client_id = 'api'\n",
    "director_node_fqdn = 'director'\n",
    "director_port = 50051\n",
    "cert_chain = '/openfl/workspace/cert/root_ca.crt'\n",
    "api_cert = '/openfl/workspace/cert/notebook.crt'\n",
    "api_private_key = '/openfl/workspace/cert/priv.key'\n",
    "\n",
    "federation = Federation(\n",
    "  client_id=client_id,\n",
    "  director_node_fqdn=director_node_fqdn,\n",
    "  director_port=director_port,\n",
    "  cert_chain=cert_chain,\n",
    "  api_cert=api_cert,\n",
    "  api_private_key=api_private_key\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74c57e78-dfc9-47f7-8a4c-96fdc48f64fc",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Query Datasets from Shard Registry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47dcfab3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "shard_registry = federation.get_shard_registry()\n",
    "shard_registry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2a6c237",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# First, request a dummy_shard_desc that holds information about the federated dataset \n",
    "# If the shard descriptor is the default dummy one, the shape of the sample and target will be \"1\", which is not the actual data shape\n",
    "dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)\n",
    "dummy_shard_dataset = dummy_shard_desc.get_dataset('train')\n",
    "sample, target = dummy_shard_dataset[0]\n",
    "f\"Sample shape: {sample.shape}, target shape: {target.shape}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc0dbdbd",
   "metadata": {},
   "source": [
    "## Describing the FL experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc88700a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from openfl.interface.interactive_api.experiment import TaskInterface\n",
    "from openfl.interface.interactive_api.experiment import ModelInterface\n",
    "from openfl.interface.interactive_api.experiment import FLExperiment"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b468ae1",
   "metadata": {},
   "source": [
    "### Register model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06545bbb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Define model\n",
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),\n",
    "    tf.keras.layers.MaxPooling2D((2, 2)),\n",
    "    tf.keras.layers.BatchNormalization(),\n",
    "    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),\n",
    "    tf.keras.layers.MaxPooling2D((2, 2)),\n",
    "    tf.keras.layers.BatchNormalization(),\n",
    "    tf.keras.layers.Flatten(),\n",
    "    tf.keras.layers.Dense(10, activation=None),\n",
    "], name='simplecnn')\n",
    "model.summary()\n",
    "\n",
    "# Define optimizer\n",
    "optimizer = tf.optimizers.Adam(learning_rate=1e-3)\n",
    "\n",
    "# Loss and metrics. These will be used later.\n",
    "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n",
    "val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n",
    "\n",
    "# Create ModelInterface\n",
    "framework_adapter = 'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'\n",
    "MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0979470",
   "metadata": {},
   "source": [
    "### Register dataset\n",
    "\n",
    "This is the main difference between this tutorial and the original \"Tensorflow_MNIST\" tutorial. Instead of configuring each Envoy with a specific shard descriptor class, the dummy one is configured for all of them, and the data retrieval logic moves into the `DataInterface` subclass defined below. For tasks that do not use the MNIST dataset, write your own local data retrieval logic here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8c9eb50",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "import numpy as np\n",
    "from tensorflow.keras.utils import Sequence\n",
    "from openfl.interface.interactive_api.shard_descriptor import ShardDataset\n",
    "\n",
    "from openfl.interface.interactive_api.experiment import DataInterface\n",
    "\n",
    "class MnistShardDataset(ShardDataset):\n",
    "    \"\"\"Mnist Shard dataset class.\"\"\"\n",
    "\n",
    "    def __init__(self, x, y, rank=1, worldsize=1):\n",
    "        \"\"\"Initialize MnistShardDataset.\"\"\"\n",
    "        self.rank = rank\n",
    "        self.worldsize = worldsize\n",
    "        self.x = x[self.rank - 1::self.worldsize]\n",
    "        self.y = y[self.rank - 1::self.worldsize]\n",
    "\n",
    "    def __getitem__(self, index: int):\n",
    "        \"\"\"Return an item by the index.\"\"\"\n",
    "        return self.x[index], self.y[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the len of the dataset.\"\"\"\n",
    "        return len(self.x)\n",
    "    \n",
    "\n",
    "class DataGenerator(Sequence):\n",
    "\n",
    "    def __init__(self, data_set, batch_size):\n",
    "        self.data_set = data_set\n",
    "        self.batch_size = batch_size\n",
    "        self.indices = np.arange(len(data_set))\n",
    "        self.on_epoch_end()\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.indices) // self.batch_size\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # self.indices is already shuffled, so a slice of it gives the sample indices for this batch\n",
    "        batch = self.indices[index * self.batch_size:(index + 1) * self.batch_size]\n",
    "\n",
    "        X, y = self.data_set[batch]\n",
    "        return X, y\n",
    "\n",
    "    def on_epoch_end(self):\n",
    "        np.random.shuffle(self.indices)\n",
    "\n",
    "\n",
    "class MnistFedDataset(DataInterface):\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    @property\n",
    "    def shard_descriptor(self):\n",
    "        return self._shard_descriptor\n",
    "\n",
    "    @shard_descriptor.setter\n",
    "    def shard_descriptor(self, shard_descriptor):\n",
    "        \"\"\"\n",
    "        Describe per-collaborator procedures or sharding.\n",
    "\n",
    "        This method will be called during a collaborator initialization.\n",
    "        Local shard_descriptor will be set by Envoy.\n",
    "        \"\"\"\n",
    "        self._shard_descriptor = shard_descriptor\n",
    "        \n",
    "        (x_train, y_train), (x_test, y_test) = self.download_data()\n",
    "        self.train_set = MnistShardDataset(x_train, y_train)\n",
    "        self.valid_set = MnistShardDataset(x_test, y_test)\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        return self.shard_descriptor[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.shard_descriptor)\n",
    "    \n",
    "    def get_train_loader(self):\n",
    "        \"\"\"\n",
    "        Output of this method will be provided to tasks with optimizer in contract\n",
    "        \"\"\"\n",
    "        # Fall back to 32 when train_bs is missing or falsy (avoids a KeyError)\n",
    "        batch_size = self.kwargs.get('train_bs') or 32\n",
    "        return DataGenerator(self.train_set, batch_size=batch_size)\n",
    "\n",
    "    def get_valid_loader(self):\n",
    "        \"\"\"\n",
    "        Output of this method will be provided to tasks without optimizer in contract\n",
    "        \"\"\"\n",
    "        # Fall back to 32 when valid_bs is missing or falsy (avoids a KeyError)\n",
    "        batch_size = self.kwargs.get('valid_bs') or 32\n",
    "        \n",
    "        return DataGenerator(self.valid_set, batch_size=batch_size)\n",
    "\n",
    "    def get_train_data_size(self):\n",
    "        \"\"\"\n",
    "        Information for aggregation\n",
    "        \"\"\"\n",
    "        \n",
    "        return len(self.train_set)\n",
    "\n",
    "    def get_valid_data_size(self):\n",
    "        \"\"\"\n",
    "        Information for aggregation\n",
    "        \"\"\"\n",
    "        return len(self.valid_set)\n",
    "    \n",
    "    def download_data(self):\n",
    "        \"\"\"Download prepared dataset.\"\"\"\n",
    "        local_file_path = 'mnist.npz'\n",
    "        mnist_url = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz'\n",
    "        response = requests.get(mnist_url)\n",
    "        response.raise_for_status()  # fail early on an HTTP error instead of saving an error page\n",
    "        with open(local_file_path, 'wb') as f:\n",
    "            f.write(response.content)\n",
    "\n",
    "        with np.load(local_file_path) as f:\n",
    "            x_train, y_train = f['x_train'], f['y_train']\n",
    "            x_test, y_test = f['x_test'], f['y_test']\n",
    "            x_train = np.reshape(x_train, (-1, 28, 28, 1))\n",
    "            x_test = np.reshape(x_test, (-1, 28, 28, 1))\n",
    "\n",
    "        os.remove(local_file_path)  # remove mnist.npz\n",
    "        print('Mnist data was loaded!')\n",
    "        return (x_train, y_train), (x_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0dfb459",
   "metadata": {},
   "source": [
    "### Create MNIST federated dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4af5c4c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fed_dataset = MnistFedDataset(train_bs=64, valid_bs=512)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "849c165b",
   "metadata": {},
   "source": [
    "## Define and register FL tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9649385",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "\n",
    "\n",
    "TI = TaskInterface()\n",
    "\n",
    "# from openfl.interface.aggregation_functions import AdagradAdaptiveAggregation    # Uncomment these lines to use the\n",
    "# agg_fn = AdagradAdaptiveAggregation(model_interface=MI, learning_rate=0.4)       # Adaptive Federated Optimization\n",
    "# @TI.set_aggregation_function(agg_fn)                                             # algorithm.\n",
    "#                                                                                  # For details, see:\n",
    "#                                                                                  # https://arxiv.org/abs/2003.00295\n",
    "\n",
    "@TI.register_fl_task(model='model', data_loader='train_dataset', device='device', optimizer='optimizer')     \n",
    "def train(model, train_dataset, optimizer, device, loss_fn=loss_fn, warmup=False):\n",
    "    start_time = time.time()\n",
    "\n",
    "    # Iterate over the batches of the dataset.\n",
    "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
    "        with tf.GradientTape() as tape:\n",
    "            logits = model(x_batch_train, training=True)\n",
    "            loss_value = loss_fn(y_batch_train, logits)\n",
    "        grads = tape.gradient(loss_value, model.trainable_weights)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "\n",
    "        # Update training metric.\n",
    "        train_acc_metric.update_state(y_batch_train, logits)\n",
    "\n",
    "        # Log every 200 batches.\n",
    "        if step % 200 == 0:\n",
    "            print(\n",
    "                \"Training loss (for one batch) at step %d: %.4f\"\n",
    "                % (step, float(loss_value))\n",
    "            )\n",
    "            print(\"Seen so far: %d samples\" % ((step + 1) * 64))\n",
    "        if warmup:\n",
    "            break\n",
    "\n",
    "    # Display metrics at the end of each epoch.\n",
    "    train_acc = train_acc_metric.result()\n",
    "    print(\"Training acc over epoch: %.4f\" % (float(train_acc),))\n",
    "\n",
    "    # Reset training metrics at the end of each epoch\n",
    "    train_acc_metric.reset_states()\n",
    "\n",
    "        \n",
    "    return {'train_acc': train_acc,}\n",
    "\n",
    "\n",
    "@TI.register_fl_task(model='model', data_loader='val_dataset', device='device')     \n",
    "def validate(model, val_dataset, device):\n",
    "    # Run a validation loop at the end of each epoch.\n",
    "    for x_batch_val, y_batch_val in val_dataset:\n",
    "        val_logits = model(x_batch_val, training=False)\n",
    "        # Update val metrics\n",
    "        val_acc_metric.update_state(y_batch_val, val_logits)\n",
    "    val_acc = val_acc_metric.result()\n",
    "    val_acc_metric.reset_states()\n",
    "    print(\"Validation acc: %.4f\" % (float(val_acc),))\n",
    "            \n",
    "    return {'validation_accuracy': val_acc,}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f0ebf2d",
   "metadata": {},
   "source": [
    "## Time to start a federated learning experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d41b7896",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create an experiment in the federation\n",
    "experiment_name = 'mnist_experiment'\n",
    "fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f2074ca-9dcc-48ad-93fe-be4f65479b7b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# print the default federated learning plan\n",
    "import openfl.native as fx\n",
    "print(fx.get_plan(fl_plan=fl_experiment.plan))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41b44de9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# The following command zips the workspace and Python requirements to be transferred to collaborator nodes\n",
    "fl_experiment.start(model_provider=MI, \n",
    "                   task_keeper=TI,\n",
    "                   data_loader=fed_dataset,\n",
    "                   rounds_to_train=5,\n",
    "                   opt_treatment='CONTINUE_GLOBAL',\n",
    "                   override_config={'aggregator.settings.db_store_rounds': 1, 'compression_pipeline.template': 'openfl.pipelines.KCPipeline', 'compression_pipeline.settings.n_clusters': 2})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01fa7cea",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fl_experiment.stream_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba13c3d4-bc2f-4bdb-86f0-5a72d827d9c8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
